import argparse

import yaml
from timm.utils import add_bool_arg


def _build_config_parser():
    """Return a pre-parser that recognises only ``-c/--config``.

    It runs before the main parser so that a YAML file named on the command
    line can override the main parser's defaults. ``add_help=False`` leaves
    ``-h/--help`` to be handled by the main parser.
    """
    config_parser = argparse.ArgumentParser(
        description="Training Config", add_help=False
    )
    config_parser.add_argument(
        "-c",
        "--config",
        default="",
        type=str,
        metavar="FILE",
        help="YAML config file specifying default arguments",
    )
    return config_parser


def _apply_yaml_defaults(parser, config_path):
    """Override ``parser``'s defaults with the key/values in a YAML file.

    Keys are passed straight to ``parser.set_defaults``. They must therefore
    use argument *dest* names (``batch_size``, not ``--batch-size``). Keys
    that match no argument still appear as attributes on the parsed namespace.

    Args:
        parser: the ``argparse.ArgumentParser`` whose defaults are overridden.
        config_path: path to the YAML file. An empty file means "no overrides".

    Raises:
        TypeError: if the YAML document is not a mapping.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document; treat that as no overrides.
        cfg = yaml.safe_load(f) or {}
    if not isinstance(cfg, dict):
        raise TypeError(
            f"Config file {config_path!r} must contain a YAML mapping, "
            f"got {type(cfg).__name__}"
        )
    parser.set_defaults(**cfg)


def _build_parser():
    """Return the main training/evaluation argument parser with built-in defaults."""
    parser = argparse.ArgumentParser(description="PyTorch ImageNet Training")

    # Dataset / Model parameters
    parser.add_argument(
        "--dataset",
        default="coco",
        type=str,
        metavar="DATASET",
        help='Name of dataset to train (default: "coco")',
    )
    parser.add_argument(
        "--model",
        default="tf_efficientdet_d1",
        type=str,
        metavar="MODEL",
        help='Name of model to train (default: "tf_efficientdet_d1")',
    )
    # default=None keeps these tri-state: presumably None means "leave the
    # model config value alone" (per the help text) — confirm with the caller.
    add_bool_arg(
        parser,
        "redundant-bias",
        default=None,
        help="override model config for redundant bias",
    )
    add_bool_arg(
        parser, "soft-nms", default=None, help="override model config for soft-nms"
    )
    parser.add_argument(
        "--val-skip",
        type=int,
        default=0,
        metavar="N",
        help="Skip every N validation samples.",
    )
    parser.add_argument(
        "--num-classes",
        type=int,
        default=None,
        metavar="N",
        help="Override num_classes in model config if set. For fine-tuning from pretrained.",
    )
    parser.add_argument(
        "--pretrained",
        action="store_true",
        default=False,
        help="Start with pretrained version of specified network (if avail)",
    )
    parser.add_argument(
        "--no-pretrained-backbone",
        action="store_true",
        default=False,
        help="Do not start with pretrained backbone weights, fully random.",
    )
    parser.add_argument(
        "--initial-checkpoint",
        default="",
        type=str,
        metavar="PATH",
        help="Initialize model from this checkpoint (default: none)",
    )
    parser.add_argument(
        "--resume",
        default="",
        type=str,
        metavar="PATH",
        help="Resume full model and optimizer state from checkpoint (default: none)",
    )
    parser.add_argument(
        "--no-resume-opt",
        action="store_true",
        default=False,
        help="prevent resume of optimizer state when resuming model",
    )
    parser.add_argument(
        "--mean",
        type=float,
        nargs="+",
        default=None,
        metavar="MEAN",
        help="Override mean pixel value of dataset",
    )
    parser.add_argument(
        "--std",
        type=float,
        nargs="+",
        default=None,
        metavar="STD",
        help="Override std deviation of dataset",
    )
    parser.add_argument(
        "--interpolation",
        default="",
        type=str,
        metavar="NAME",
        help="Image resize interpolation type (overrides model)",
    )
    parser.add_argument(
        "--fill-color",
        default=None,
        type=str,
        metavar="NAME",
        help='Image augmentation fill (background) color ("mean" or int)',
    )
    parser.add_argument(
        "--batch-size",
        type=int,
        default=32,
        metavar="N",
        help="input batch size for training (default: 32)",
    )
    parser.add_argument(
        "--clip-grad",
        type=float,
        default=10.0,
        metavar="NORM",
        help="Clip gradient norm (default: 10.0)",
    )

    # Optimizer parameters
    parser.add_argument(
        "--opt",
        default="momentum",
        type=str,
        metavar="OPTIMIZER",
        help='Optimizer (default: "momentum")',
    )
    parser.add_argument(
        "--opt-eps",
        default=1e-3,
        type=float,
        metavar="EPSILON",
        help="Optimizer Epsilon (default: 1e-3)",
    )
    parser.add_argument(
        "--momentum",
        type=float,
        default=0.9,
        metavar="M",
        help="SGD momentum (default: 0.9)",
    )
    parser.add_argument(
        "--weight-decay",
        type=float,
        default=4e-5,
        help="weight decay (default: 0.00004)",
    )

    # Learning rate schedule parameters
    parser.add_argument(
        "--sched",
        default="cosine",
        type=str,
        metavar="SCHEDULER",
        help='LR scheduler (default: "cosine")',
    )
    parser.add_argument(
        "--lr",
        type=float,
        default=0.01,
        metavar="LR",
        help="learning rate (default: 0.01)",
    )
    parser.add_argument(
        "--lr-noise",
        type=float,
        nargs="+",
        default=None,
        metavar="pct, pct",
        help="learning rate noise on/off epoch percentages",
    )
    parser.add_argument(
        "--lr-noise-pct",
        type=float,
        default=0.67,
        metavar="PERCENT",
        help="learning rate noise limit percent (default: 0.67)",
    )
    parser.add_argument(
        "--lr-noise-std",
        type=float,
        default=1.0,
        metavar="STDDEV",
        help="learning rate noise std-dev (default: 1.0)",
    )
    parser.add_argument(
        "--lr-cycle-mul",
        type=float,
        default=1.0,
        metavar="MULT",
        help="learning rate cycle len multiplier (default: 1.0)",
    )
    parser.add_argument(
        "--lr-cycle-limit",
        type=int,
        default=1,
        metavar="N",
        help="learning rate cycle limit",
    )
    parser.add_argument(
        "--warmup-lr",
        type=float,
        default=0.0001,
        metavar="LR",
        help="warmup learning rate (default: 0.0001)",
    )
    parser.add_argument(
        "--min-lr",
        type=float,
        default=1e-5,
        metavar="LR",
        help="lower lr bound for cyclic schedulers that hit 0 (1e-5)",
    )
    parser.add_argument(
        "--epochs",
        type=int,
        default=300,
        metavar="N",
        help="number of epochs to train (default: 300)",
    )
    parser.add_argument(
        "--start-epoch",
        default=None,
        type=int,
        metavar="N",
        help="manual epoch number (useful on restarts)",
    )
    parser.add_argument(
        "--decay-epochs",
        type=float,
        default=30,
        metavar="N",
        help="epoch interval to decay LR",
    )
    parser.add_argument(
        "--warmup-epochs",
        type=int,
        default=5,
        metavar="N",
        help="epochs to warmup LR, if scheduler supports",
    )
    parser.add_argument(
        "--cooldown-epochs",
        type=int,
        default=10,
        metavar="N",
        help="epochs to cooldown LR at min_lr, after cyclic schedule ends",
    )
    parser.add_argument(
        "--patience-epochs",
        type=int,
        default=10,
        metavar="N",
        help="patience epochs for Plateau LR scheduler (default: 10)",
    )
    parser.add_argument(
        "--decay-rate",
        "--dr",
        type=float,
        default=0.1,
        metavar="RATE",
        help="LR decay rate (default: 0.1)",
    )

    # Augmentation parameters
    parser.add_argument(
        "--color-jitter",
        type=float,
        default=0.4,
        metavar="PCT",
        help="Color jitter factor (default: 0.4)",
    )
    parser.add_argument(
        "--aa",
        type=str,
        default=None,
        metavar="NAME",
        help='Use AutoAugment policy. "v0" or "original". (default: None)',
    )
    parser.add_argument(
        "--reprob",
        type=float,
        default=0.0,
        metavar="PCT",
        help="Random erase prob (default: 0.)",
    )
    parser.add_argument(
        "--remode",
        type=str,
        default="pixel",
        help='Random erase mode (default: "pixel")',
    )
    parser.add_argument(
        "--recount", type=int, default=1, help="Random erase count (default: 1)"
    )
    parser.add_argument(
        "--train-interpolation",
        type=str,
        default="random",
        help='Training interpolation (random, bilinear, bicubic default: "random")',
    )

    # Loss
    parser.add_argument(
        "--smoothing",
        type=float,
        default=None,
        help="override model config label smoothing",
    )
    add_bool_arg(
        parser,
        "jit-loss",
        default=None,
        help="override model config for torchscript jit loss fn",
    )
    add_bool_arg(
        parser,
        "legacy-focal",
        default=None,
        help="override model config to use legacy focal loss",
    )

    # Model Exponential Moving Average
    parser.add_argument(
        "--model-ema",
        action="store_true",
        default=False,
        help="Enable tracking moving average of model weights",
    )
    parser.add_argument(
        "--model-ema-decay",
        type=float,
        default=0.9998,
        help="decay factor for model weights moving average (default: 0.9998)",
    )

    # Misc
    parser.add_argument(
        "--sync-bn",
        action="store_true",
        help="Enable NVIDIA Apex or Torch synchronized BatchNorm.",
    )
    parser.add_argument(
        "--dist-bn",
        type=str,
        default="",
        help='Distribute BatchNorm stats between nodes after each epoch ("broadcast", "reduce", or "")',
    )
    parser.add_argument(
        "--seed", type=int, default=42, metavar="S", help="random seed (default: 42)"
    )
    parser.add_argument(
        "--log-interval",
        type=int,
        default=50,
        metavar="N",
        help="how many batches to wait before logging training status",
    )
    parser.add_argument(
        "--recovery-interval",
        type=int,
        default=0,
        metavar="N",
        help="how many batches to wait before writing recovery checkpoint",
    )
    parser.add_argument(
        "-j",
        "--workers",
        type=int,
        default=0,
        metavar="N",
        help="how many training processes to use (default: 0)",
    )
    parser.add_argument(
        "--save-images",
        action="store_true",
        default=False,
        help="save images of input batches every log interval for debugging",
    )
    parser.add_argument(
        "--amp",
        action="store_true",
        default=False,
        help="use NVIDIA Apex AMP or Native AMP for mixed precision training",
    )
    parser.add_argument(
        "--apex-amp",
        action="store_true",
        default=False,
        help="Use NVIDIA Apex AMP mixed precision",
    )
    parser.add_argument(
        "--native-amp",
        action="store_true",
        default=False,
        help="Use Native Torch AMP mixed precision",
    )
    parser.add_argument(
        "--channels-last",
        action="store_true",
        default=False,
        help="Use channels_last memory layout",
    )
    parser.add_argument(
        "--pin-mem",
        action="store_true",
        default=False,
        help="Pin CPU memory in DataLoader for more efficient (sometimes) transfer to GPU.",
    )
    parser.add_argument(
        "--no-prefetcher",
        action="store_true",
        default=False,
        help="disable fast prefetcher",
    )
    parser.add_argument(
        "--torchscript",
        dest="torchscript",
        action="store_true",
        help="convert model torchscript for inference",
    )
    add_bool_arg(
        parser,
        "bench-labeler",
        default=False,
        help="label targets in model bench, increases GPU load at expense of loader processes",
    )
    parser.add_argument(
        "--output",
        default="",
        type=str,
        metavar="PATH",
        help="path to output folder (default: none, current dir)",
    )
    parser.add_argument(
        "--eval-metric",
        default="map",
        type=str,
        metavar="EVAL_METRIC",
        help='Best metric (default: "map")',
    )
    parser.add_argument(
        "--tta",
        type=int,
        default=0,
        metavar="N",
        help="Test/inference time augmentation (oversampling) factor. 0=None (default: 0)",
    )
    parser.add_argument("--local_rank", default=0, type=int)

    # Evaluation parameters
    parser.add_argument(
        "--eval-interpolation",
        default="bilinear",
        type=str,
        metavar="NAME",
        help="Image resize interpolation type (overrides model)",
    )
    parser.add_argument(
        "--img-size",
        default=None,
        type=int,
        metavar="N",
        help="Input image dimension, uses model default if empty",
    )
    parser.add_argument(
        "--checkpoint",
        default="",
        type=str,
        metavar="PATH",
        help="path to latest checkpoint (default: none)",
    )
    parser.add_argument(
        "--use-ema",
        dest="use_ema",
        action="store_true",
        help="use ema version of weights if present",
    )
    return parser


def get_args(config_file=None):
    """Parse training/evaluation arguments from the command line (``sys.argv``).

    Values are resolved in three layers, from lowest to highest precedence:

    1. the argparse defaults declared in ``_build_parser``;
    2. a YAML config file, if one is given;
    3. flags given explicitly on the command line.

    Args:
        config_file: optional path to a YAML config file. When it is truthy,
            it is used instead of any ``-c/--config`` path on the command line.

    Returns:
        ``argparse.Namespace`` with one attribute per argument. Unrecognized
        command-line options are ignored rather than raising.
    """
    config_parser = _build_config_parser()
    parser = _build_parser()

    # Pull -c/--config out first. The main parser does not declare it, so it
    # only ever sees the remaining arguments.
    config_args, remaining = config_parser.parse_known_args()
    config_path = config_file or config_args.config
    if config_path:
        _apply_yaml_defaults(parser, config_path)

    # parse_known_args tolerates leftover unrecognized options instead of erroring.
    # Any YAML overrides applied above are already in place as defaults.
    args, _ = parser.parse_known_args(remaining)
    return args
